/*
 * Copyright 2019-2025 JetBrains s.r.o. and contributors.
 * Use of this source code is governed by the Apache 2.0 License that can be found in the LICENSE.txt file.
 */

package kotlinx.datetime.test

import kotlinx.datetime.test.MaliciousJvmSerializationTest.TestCase.Streams
import java.io.ByteArrayInputStream
import java.io.ObjectInputStream
import java.io.ObjectStreamClass
import java.io.Serializable
import kotlin.reflect.KClass
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.fail

class MaliciousJvmSerializationTest {

    /**
     * A set of hand-crafted serial streams for one `java.io.Serializable` kotlinx-datetime class. Each stream describes
     * an instance of [clazz] directly, i.e., it tries to bypass the `kotlinx.datetime.Ser` serialization proxy.
     *
     * This data was generated by running the following Java code (`X` was replaced with [clazz]`.simpleName`, `Y` with
     * [delegate]`::class.qualifiedName` and `z` with [delegateFieldName]):
     * ```java
     * package kotlinx.datetime;
     *
     * import java.io.*;
     * import java.util.*;
     *
     * public class X implements Serializable {
     *     private final Y z = ...;
     *
     *     @Serial
     *     private static final long serialVersionUID = ...;
     *
     *     public static void main(String[] args) throws IOException {
     *         var bos = new ByteArrayOutputStream();
     *         try (var oos = new ObjectOutputStream(bos)) {
     *             oos.writeObject(new X());
     *         }
     *         System.out.println(HexFormat.of().formatHex(bos.toByteArray()));
     *     }
     * }
     * ```
     */
    private class TestCase(
        val clazz: KClass<out Serializable>,
        val delegateFieldName: String,
        val delegate: Serializable,
        /** `serialVersionUID` was set to the correct value (`0L`) in the Java code. */
        val withCorrectSVUID: Streams,
        /** `serialVersionUID` was set to an incorrect value (`42L`) in the Java code. */
        val withIncorrectSVUID: Streams,
    ) {
        /**
         * The fully qualified name of [clazz], used in assertion and failure messages.
         *
         * @throws IllegalStateException if [clazz] has no qualified name (local or anonymous classes).
         */
        val className: String
            get() = checkNotNull(clazz.qualifiedName) { "$clazz has no qualified name" }

        /** The hex-encoded serial streams that were generated for one `serialVersionUID` value. */
        class Streams(
            /** `z` was set to [delegate] in the Java code. */
            val delegateValid: String,
            /** `z` was set to `null` in the Java code. */
            val delegateNull: String,
        )
    }

    private val testCases = listOf(
        TestCase(
            kotlinx.datetime.LocalDate::class,
            delegateFieldName = "value",
            delegate = java.time.LocalDate.of(2025, 4, 26),
            withCorrectSVUID = Streams(
                delegateValid = "aced00057372001a6b6f746c696e782e6461746574696d652e4c6f63616c4461746500000000000000000200014c000576616c75657400154c6a6176612f74696d652f4c6f63616c446174653b78707372000d6a6176612e74696d652e536572955d84ba1b2248b20c00007870770703000007e9041a78",
                delegateNull = "aced00057372001a6b6f746c696e782e6461746574696d652e4c6f63616c4461746500000000000000000200014c000576616c75657400154c6a6176612f74696d652f4c6f63616c446174653b787070",
            ),
            withIncorrectSVUID = Streams(
                delegateValid = "aced00057372001a6b6f746c696e782e6461746574696d652e4c6f63616c44617465000000000000002a0200014c000576616c75657400154c6a6176612f74696d652f4c6f63616c446174653b78707372000d6a6176612e74696d652e536572955d84ba1b2248b20c00007870770703000007e9041a78",
                delegateNull = "aced00057372001a6b6f746c696e782e6461746574696d652e4c6f63616c44617465000000000000002a0200014c000576616c75657400154c6a6176612f74696d652f4c6f63616c446174653b787070",
            ),
        ),
        TestCase(
            kotlinx.datetime.LocalDateTime::class,
            delegateFieldName = "value",
            delegate = java.time.LocalDateTime.of(2025, 4, 26, 11, 18),
            withCorrectSVUID = Streams(
                delegateValid = "aced00057372001e6b6f746c696e782e6461746574696d652e4c6f63616c4461746554696d6500000000000000000200014c000576616c75657400194c6a6176612f74696d652f4c6f63616c4461746554696d653b78707372000d6a6176612e74696d652e536572955d84ba1b2248b20c00007870770905000007e9041a0bed78",
                delegateNull = "aced00057372001e6b6f746c696e782e6461746574696d652e4c6f63616c4461746554696d6500000000000000000200014c000576616c75657400194c6a6176612f74696d652f4c6f63616c4461746554696d653b787070",
            ),
            withIncorrectSVUID = Streams(
                delegateValid = "aced00057372001e6b6f746c696e782e6461746574696d652e4c6f63616c4461746554696d65000000000000002a0200014c000576616c75657400194c6a6176612f74696d652f4c6f63616c4461746554696d653b78707372000d6a6176612e74696d652e536572955d84ba1b2248b20c00007870770905000007e9041a0bed78",
                delegateNull = "aced00057372001e6b6f746c696e782e6461746574696d652e4c6f63616c4461746554696d65000000000000002a0200014c000576616c75657400194c6a6176612f74696d652f4c6f63616c4461746554696d653b787070",
            ),
        ),
        TestCase(
            kotlinx.datetime.LocalTime::class,
            delegateFieldName = "value",
            delegate = java.time.LocalTime.of(11, 18),
            withCorrectSVUID = Streams(
                delegateValid = "aced00057372001a6b6f746c696e782e6461746574696d652e4c6f63616c54696d6500000000000000000200014c000576616c75657400154c6a6176612f74696d652f4c6f63616c54696d653b78707372000d6a6176612e74696d652e536572955d84ba1b2248b20c000078707703040bed78",
                delegateNull = "aced00057372001a6b6f746c696e782e6461746574696d652e4c6f63616c54696d6500000000000000000200014c000576616c75657400154c6a6176612f74696d652f4c6f63616c54696d653b787070",
            ),
            withIncorrectSVUID = Streams(
                delegateValid = "aced00057372001a6b6f746c696e782e6461746574696d652e4c6f63616c54696d65000000000000002a0200014c000576616c75657400154c6a6176612f74696d652f4c6f63616c54696d653b78707372000d6a6176612e74696d652e536572955d84ba1b2248b20c000078707703040bed78",
                delegateNull = "aced00057372001a6b6f746c696e782e6461746574696d652e4c6f63616c54696d65000000000000002a0200014c000576616c75657400154c6a6176612f74696d652f4c6f63616c54696d653b787070",
            ),
        ),
        TestCase(
            kotlinx.datetime.UtcOffset::class,
            delegateFieldName = "zoneOffset",
            delegate = java.time.ZoneOffset.UTC,
            withCorrectSVUID = Streams(
                delegateValid = "aced00057372001a6b6f746c696e782e6461746574696d652e5574634f666673657400000000000000000200014c000a7a6f6e654f66667365747400164c6a6176612f74696d652f5a6f6e654f66667365743b78707372000d6a6176612e74696d652e536572955d84ba1b2248b20c000078707702080078",
                delegateNull = "aced00057372001a6b6f746c696e782e6461746574696d652e5574634f666673657400000000000000000200014c000a7a6f6e654f66667365747400164c6a6176612f74696d652f5a6f6e654f66667365743b787070",
            ),
            withIncorrectSVUID = Streams(
                delegateValid = "aced00057372001a6b6f746c696e782e6461746574696d652e5574634f6666736574000000000000002a0200014c000a7a6f6e654f66667365747400164c6a6176612f74696d652f5a6f6e654f66667365743b78707372000d6a6176612e74696d652e536572955d84ba1b2248b20c000078707702080078",
                delegateNull = "aced00057372001a6b6f746c696e782e6461746574696d652e5574634f6666736574000000000000002a0200014c000a7a6f6e654f66667365747400164c6a6176612f74696d652f5a6f6e654f66667365743b787070",
            ),
        ),
        TestCase(
            kotlinx.datetime.YearMonth::class,
            delegateFieldName = "value",
            delegate = java.time.YearMonth.of(2025, 4),
            withCorrectSVUID = Streams(
                delegateValid = "aced00057372001a6b6f746c696e782e6461746574696d652e596561724d6f6e746800000000000000000200014c000576616c75657400154c6a6176612f74696d652f596561724d6f6e74683b78707372000d6a6176612e74696d652e536572955d84ba1b2248b20c0000787077060c000007e90478",
                delegateNull = "aced00057372001a6b6f746c696e782e6461746574696d652e596561724d6f6e746800000000000000000200014c000576616c75657400154c6a6176612f74696d652f596561724d6f6e74683b787070",
            ),
            withIncorrectSVUID = Streams(
                delegateValid = "aced00057372001a6b6f746c696e782e6461746574696d652e596561724d6f6e7468000000000000002a0200014c000576616c75657400154c6a6176612f74696d652f596561724d6f6e74683b78707372000d6a6176612e74696d652e536572955d84ba1b2248b20c0000787077060c000007e90478",
                delegateNull = "aced00057372001a6b6f746c696e782e6461746574696d652e596561724d6f6e7468000000000000002a0200014c000576616c75657400154c6a6176612f74696d652f596561724d6f6e74683b787070",
            ),
        ),
    )

    /**
     * Decodes the hex-encoded [stream] and reads a single object from it with an [ObjectInputStream].
     *
     * @return the deserialized object.
     */
    @OptIn(ExperimentalStdlibApi::class)
    private fun deserialize(stream: String): Any? {
        val bis = ByteArrayInputStream(stream.hexToByteArray())
        return ObjectInputStream(bis).use { ois ->
            ois.readObject()
        }
    }

    /**
     * For every entry of [testCases], first verifies that the hard-coded streams still match the class under test and
     * then checks that none of the malicious streams can be deserialized.
     */
    @Test
    fun deserializeMaliciousStreams() {
        for (testCase in testCases) {
            testCase.ensureAssumptionsHold()
            val className = testCase.className
            testStreamsWithCorrectSVUID(className, testCase.withCorrectSVUID)
            testStreamsWithIncorrectSVUID(className, testCase.withIncorrectSVUID)
        }
    }

    /**
     * Fails the test if [TestCase.clazz] no longer has the serialized form that the hard-coded streams were generated
     * for: a `serialVersionUID` of `0` and exactly one serializable field, named [TestCase.delegateFieldName], whose
     * type is the runtime class of [TestCase.delegate].
     */
    private fun TestCase.ensureAssumptionsHold() {
        val objectStreamClass = ObjectStreamClass.lookup(clazz.java)

        val actualSerialVersionUID = objectStreamClass.serialVersionUID
        if (actualSerialVersionUID != 0L) {
            fail("This test assumes that the serialVersionUID of $className is 0, but it was $actualSerialVersionUID.")
        }

        val field = objectStreamClass.fields.singleOrNull()
        if (field == null || field.name != delegateFieldName || field.type != delegate.javaClass) {
            fail(
                "This test assumes that $className has a single serializable field named '$delegateFieldName' of " +
                    "type ${delegate::class.qualifiedName}. The test case for $className should be updated with new " +
                    "malicious serial streams that represent the changes to $className."
            )
        }
    }

    /**
     * Asserts that both [streams], whose `serialVersionUID` matches the local one of [className], are rejected with a
     * [java.io.InvalidObjectException] stating that [className] must be deserialized via `kotlinx.datetime.Ser`.
     */
    private fun testStreamsWithCorrectSVUID(className: String, streams: Streams) {
        val testFailureMessage = "Deserialization of a serial stream that tries to bypass kotlinx.datetime.Ser and " +
            "has the correct serialVersionUID for $className should fail"

        val expectedIOEMessage = "$className must be deserialized via kotlinx.datetime.Ser"

        // this would actually create a valid instance, but deserialization must always go through the proxy
        assertDeserializationFails<java.io.InvalidObjectException>(
            stream = streams.delegateValid,
            testFailureMessage = testFailureMessage,
            expectedExceptionMessage = expectedIOEMessage,
        )

        // this would create an instance that has null in a non-nullable field (e.g., the field
        // kotlinx.datetime.LocalDate.value)
        // see https://github.com/Kotlin/kotlinx-datetime/pull/373#discussion_r2008922681
        assertDeserializationFails<java.io.InvalidObjectException>(
            stream = streams.delegateNull,
            testFailureMessage = testFailureMessage,
            expectedExceptionMessage = expectedIOEMessage,
        )
    }

    /**
     * Asserts that both [streams], which declare `serialVersionUID = 42` for [className] while the local class has
     * `0`, are rejected by the JDK's class-compatibility check with a [java.io.InvalidClassException].
     */
    private fun testStreamsWithIncorrectSVUID(className: String, streams: Streams) {
        val testFailureMessage = "Deserialization of a serial stream that tries to bypass kotlinx.datetime.Ser but " +
            "has a wrong serialVersionUID for $className should fail"

        val expectedICEMessage = "$className; local class incompatible: stream classdesc serialVersionUID = 42, " +
            "local class serialVersionUID = 0"

        assertDeserializationFails<java.io.InvalidClassException>(
            stream = streams.delegateValid,
            testFailureMessage = testFailureMessage,
            expectedExceptionMessage = expectedICEMessage,
        )

        assertDeserializationFails<java.io.InvalidClassException>(
            stream = streams.delegateNull,
            testFailureMessage = testFailureMessage,
            expectedExceptionMessage = expectedICEMessage,
        )
    }

    /**
     * Asserts that deserializing the hex-encoded [stream] throws an exception of type [T] whose message is exactly
     * [expectedExceptionMessage]. [testFailureMessage] is used as the assertion message if no [T] is thrown.
     */
    private inline fun <reified T : Throwable> assertDeserializationFails(
        stream: String,
        testFailureMessage: String,
        expectedExceptionMessage: String,
    ) {
        val exception = assertFailsWith<T>(testFailureMessage) {
            deserialize(stream)
        }
        assertEquals(expectedExceptionMessage, exception.message)
    }
}
